
Conversation

@JohannesGaessler commented Sep 14, 2025

This PR fixes a bug in the CUDA FlashAttention occupancy calculation. In rare cases too few kernels would be launched in parallel, resulting in a few percent lower performance.
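
For context on what the occupancy calculation is for: the launch code needs to know how many blocks of the kernel can be resident per SM given its register and shared-memory footprint, so that enough blocks are launched to keep the whole GPU busy. Below is a generic CUDA sketch of such a query using cudaOccupancyMaxActiveBlocksPerMultiprocessor; it is only an illustration, not the custom calculation this PR fixes, and the stub kernel, block size, and shared-memory size are made up.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for a FlashAttention tile kernel.
__global__ void fattn_tile_stub(const float * Q, const float * K,
                                const float * V, float * dst) {
    // Body omitted; only the resource footprint matters for occupancy.
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        dst[0] = Q[0] + K[0] + V[0];
    }
}

int main() {
    const int    block_size = 256;       // threads per block (assumed)
    const size_t smem_bytes = 32 * 1024; // dynamic shared memory per block (assumed)

    // How many blocks of this kernel fit on one SM at the same time?
    int blocks_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_sm, fattn_tile_stub, block_size, smem_bytes);

    int n_sm = 0;
    cudaDeviceGetAttribute(&n_sm, cudaDevAttrMultiProcessorCount, /*device =*/ 0);

    // Launching fewer blocks than this leaves SMs idle and costs throughput.
    printf("resident blocks on the whole GPU: %d\n", blocks_per_sm * n_sm);
    return 0;
}
```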

This PR also delivers what I think will be the last round of performance optimizations for the tile FA kernel: I revised the memory layout to consistently copy data in 8/16 byte chunks and delayed writing the KQ accumulators to shared memory until after they have been compressed to FP16. I also looked up the amount of shared memory on each AMD GPU and fit the tile sizes accordingly. One thing that could still be done is the same GQA optimization as for the mma kernel, but because the GPUs using the tile kernel are comparatively slow, reducing the mask I/O has little impact; it could improve performance for small batch sizes > 1, though.
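
On the 8/16-byte copies: the usual way to get wide copies is to reinterpret the data as built-in vector types such as int4 (16 bytes), so each thread issues one wide load and store per chunk instead of several narrow ones. The kernel below is a generic sketch of staging a tile into shared memory this way; it is not the actual tile kernel, and it assumes the pointers are 16-byte aligned and the tile size is a multiple of 16 bytes.

```cpp
// Generic sketch: copy n_chunks 16-byte chunks from global memory into
// dynamic shared memory, one int4 per thread per iteration.
extern "C" __global__ void stage_tile_16B(const int4 * __restrict__ src,
                                          int4 * __restrict__ dst,
                                          const int n_chunks) {
    extern __shared__ int4 tile[]; // dynamic shared memory

    for (int i = threadIdx.x; i < n_chunks; i += blockDim.x) {
        tile[i] = src[i]; // one 16-byte load and one 16-byte store
    }
    __syncthreads();

    // In a real kernel the tile would be consumed here; write it back
    // so this sketch has an observable effect.
    for (int i = threadIdx.x; i < n_chunks; i += blockDim.x) {
        dst[i] = tile[i];
    }
}
```

A launch would pass the tile size as dynamic shared memory, e.g. `stage_tile_16B<<<grid, block, n_chunks * sizeof(int4)>>>(src, dst, n_chunks)`.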

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s 4a60861 | Speedup |
|---|---|---|---|---|---|---|
| MI60 / MI50 | gemma 2B Q4_0 | 16 | pp16384 | 715.37 | 723.22 | 1.01 |
| MI60 / MI50 | gemma 2B Q4_0 | 32 | pp16384 | 911.50 | 922.04 | 1.01 |
| MI60 / MI50 | gemma 2B Q4_0 | 64 | pp16384 | 1002.59 | 1037.99 | 1.04 |
| MI60 / MI50 | gemma 2B Q4_0 | 128 | pp16384 | 1571.26 | 1632.54 | 1.04 |
| MI60 / MI50 | gemma 2B Q4_0 | 256 | pp16384 | 1960.46 | 2104.29 | 1.07 |
| MI60 / MI50 | gemma 2B Q4_0 | 512 | pp16384 | 2137.20 | 2309.96 | 1.08 |
| MI60 / MI50 | gemma 2B Q4_0 | 1024 | pp16384 | 2282.16 | 2504.71 | 1.10 |
| MI60 / MI50 | gemma 2B Q4_0 | 2048 | pp16384 | 2302.98 | 2529.14 | 1.10 |
| MI60 / MI50 | gemma 2B Q4_0 | 4096 | pp16384 | 2264.73 | 2503.34 | 1.11 |
| MI60 / MI50 | gemma 2B Q4_0 | 8192 | pp16384 | 2201.83 | 2421.21 | 1.10 |
| MI60 / MI50 | gemma 2B Q4_0 | 16384 | pp16384 | 1982.16 | 2153.51 | 1.09 |
| MI60 / MI50 | llama 1B Q4_0 | 16 | pp16384 | 997.74 | 1063.29 | 1.07 |
| MI60 / MI50 | llama 1B Q4_0 | 32 | pp16384 | 1330.00 | 1330.98 | 1.00 |
| MI60 / MI50 | llama 1B Q4_0 | 64 | pp16384 | 1503.46 | 1581.35 | 1.05 |
| MI60 / MI50 | llama 1B Q4_0 | 128 | pp16384 | 2132.19 | 2290.14 | 1.07 |
| MI60 / MI50 | llama 1B Q4_0 | 256 | pp16384 | 2611.92 | 2828.35 | 1.08 |
| MI60 / MI50 | llama 1B Q4_0 | 512 | pp16384 | 2901.14 | 3190.54 | 1.10 |
| MI60 / MI50 | llama 1B Q4_0 | 1024 | pp16384 | 3039.48 | 3322.31 | 1.09 |
| MI60 / MI50 | llama 1B Q4_0 | 2048 | pp16384 | 3098.75 | 3421.87 | 1.10 |
| MI60 / MI50 | llama 1B Q4_0 | 4096 | pp16384 | 3055.54 | 3380.00 | 1.11 |
| MI60 / MI50 | llama 1B Q4_0 | 8192 | pp16384 | 2892.59 | 3216.54 | 1.11 |
| MI60 / MI50 | llama 1B Q4_0 | 16384 | pp16384 | 2479.57 | 2711.34 | 1.09 |
| MI60 / MI50 | llama 8B Q4_0 | 16 | pp16384 | 293.87 | 309.78 | 1.05 |
| MI60 / MI50 | llama 8B Q4_0 | 32 | pp16384 | 367.71 | 397.84 | 1.08 |
| MI60 / MI50 | llama 8B Q4_0 | 64 | pp16384 | 402.38 | 432.39 | 1.07 |
| MI60 / MI50 | llama 8B Q4_0 | 128 | pp16384 | 499.53 | 552.83 | 1.11 |
| MI60 / MI50 | llama 8B Q4_0 | 256 | pp16384 | 556.47 | 627.75 | 1.13 |
| MI60 / MI50 | llama 8B Q4_0 | 512 | pp16384 | 601.43 | 687.65 | 1.14 |
| MI60 / MI50 | llama 8B Q4_0 | 1024 | pp16384 | 547.37 | 703.20 | 1.28 |
| MI60 / MI50 | llama 8B Q4_0 | 2048 | pp16384 | 461.60 | 707.96 | 1.53 |
| MI60 / MI50 | llama 8B Q4_0 | 4096 | pp16384 | 421.46 | 707.79 | 1.68 |
| MI60 / MI50 | llama 8B Q4_0 | 8192 | pp16384 | 407.40 | 702.20 | 1.72 |
| MI60 / MI50 | llama 8B Q4_0 | 16384 | pp16384 | 390.19 | 682.75 | 1.75 |
| RX 6800 | gemma 2B Q4_0 | 16 | pp16384 | 637.07 | 658.96 | 1.03 |
| RX 6800 | gemma 2B Q4_0 | 32 | pp16384 | 993.77 | 1005.25 | 1.01 |
| RX 6800 | gemma 2B Q4_0 | 64 | pp16384 | 1265.65 | 1281.32 | 1.01 |
| RX 6800 | gemma 2B Q4_0 | 128 | pp16384 | 1516.82 | 1540.82 | 1.02 |
| RX 6800 | gemma 2B Q4_0 | 256 | pp16384 | 1726.39 | 1752.65 | 1.02 |
| RX 6800 | gemma 2B Q4_0 | 512 | pp16384 | 1900.37 | 1927.77 | 1.01 |
| RX 6800 | gemma 2B Q4_0 | 1024 | pp16384 | 1962.85 | 1985.40 | 1.01 |
| RX 6800 | gemma 2B Q4_0 | 2048 | pp16384 | 2007.33 | 2030.19 | 1.01 |
| RX 6800 | gemma 2B Q4_0 | 4096 | pp16384 | 2026.98 | 2051.96 | 1.01 |
| RX 6800 | gemma 2B Q4_0 | 8192 | pp16384 | 1979.52 | 2000.75 | 1.01 |
| RX 6800 | llama 1B Q4_0 | 16 | pp16384 | 903.33 | 943.03 | 1.04 |
| RX 6800 | llama 1B Q4_0 | 32 | pp16384 | 1338.84 | 1315.65 | 0.98 |
| RX 6800 | llama 1B Q4_0 | 64 | pp16384 | 1668.50 | 1707.08 | 1.02 |
| RX 6800 | llama 1B Q4_0 | 128 | pp16384 | 1976.71 | 2049.74 | 1.04 |
| RX 6800 | llama 1B Q4_0 | 256 | pp16384 | 2197.42 | 2369.02 | 1.08 |
| RX 6800 | llama 1B Q4_0 | 512 | pp16384 | 2305.52 | 2511.15 | 1.09 |
| RX 6800 | llama 1B Q4_0 | 1024 | pp16384 | 2442.99 | 2606.40 | 1.07 |
| RX 6800 | llama 1B Q4_0 | 2048 | pp16384 | 2475.44 | 2629.51 | 1.06 |
| RX 6800 | llama 1B Q4_0 | 4096 | pp16384 | 2469.49 | 2637.45 | 1.07 |
| RX 6800 | llama 1B Q4_0 | 8192 | pp16384 | 2370.86 | 2493.89 | 1.05 |
| RX 6800 | llama 1B Q4_0 | 16384 | pp16384 | 2076.30 | 2176.14 | 1.05 |
| RX 6800 | llama 8B Q4_0 | 16 | pp16384 | 234.63 | 243.61 | 1.04 |
| RX 6800 | llama 8B Q4_0 | 32 | pp16384 | 328.40 | 351.80 | 1.07 |
| RX 6800 | llama 8B Q4_0 | 64 | pp16384 | 385.34 | 431.51 | 1.12 |
| RX 6800 | llama 8B Q4_0 | 128 | pp16384 | 462.65 | 514.19 | 1.11 |
| RX 6800 | llama 8B Q4_0 | 256 | pp16384 | 509.90 | 580.30 | 1.14 |
| RX 6800 | llama 8B Q4_0 | 512 | pp16384 | 532.11 | 613.16 | 1.15 |
| RX 6800 | llama 8B Q4_0 | 1024 | pp16384 | 536.30 | 622.59 | 1.16 |
| RX 6800 | llama 8B Q4_0 | 2048 | pp16384 | 526.41 | 636.03 | 1.21 |
| RX 6800 | llama 8B Q4_0 | 4096 | pp16384 | 520.42 | 637.78 | 1.23 |
| RX 6800 | llama 8B Q4_0 | 8192 | pp16384 | 514.03 | 632.39 | 1.23 |
| P40 | gemma 2B Q4_0 | 16 | pp16384 | 797.54 | 834.28 | 1.05 |
| P40 | gemma 2B Q4_0 | 32 | pp16384 | 1134.80 | 1179.20 | 1.04 |
| P40 | gemma 2B Q4_0 | 64 | pp16384 | 1348.98 | 1421.51 | 1.05 |
| P40 | gemma 2B Q4_0 | 128 | pp16384 | 1469.82 | 1565.21 | 1.06 |
| P40 | gemma 2B Q4_0 | 256 | pp16384 | 1555.06 | 1669.10 | 1.07 |
| P40 | gemma 2B Q4_0 | 512 | pp16384 | 1600.38 | 1716.29 | 1.07 |
| P40 | gemma 2B Q4_0 | 1024 | pp16384 | 1663.95 | 1792.98 | 1.08 |
| P40 | gemma 2B Q4_0 | 2048 | pp16384 | 1663.02 | 1832.83 | 1.10 |
| P40 | gemma 2B Q4_0 | 4096 | pp16384 | 1671.74 | 1834.08 | 1.10 |
| P40 | gemma 2B Q4_0 | 8192 | pp16384 | 1632.67 | 1793.12 | 1.10 |
| P40 | gemma 2B Q4_0 | 16384 | pp16384 | 1499.38 | 1637.84 | 1.09 |
| P40 | llama 1B Q4_0 | 16 | pp16384 | 1219.33 | 1211.62 | 0.99 |
| P40 | llama 1B Q4_0 | 32 | pp16384 | 1712.42 | 1746.36 | 1.02 |
| P40 | llama 1B Q4_0 | 64 | pp16384 | 2017.76 | 2056.77 | 1.02 |
| P40 | llama 1B Q4_0 | 128 | pp16384 | 2230.04 | 2277.90 | 1.02 |
| P40 | llama 1B Q4_0 | 256 | pp16384 | 2434.15 | 2490.01 | 1.02 |
| P40 | llama 1B Q4_0 | 512 | pp16384 | 2495.92 | 2550.39 | 1.02 |
| P40 | llama 1B Q4_0 | 1024 | pp16384 | 2572.93 | 2660.26 | 1.03 |
| P40 | llama 1B Q4_0 | 2048 | pp16384 | 2622.93 | 2689.04 | 1.03 |
| P40 | llama 1B Q4_0 | 4096 | pp16384 | 2614.92 | 2676.44 | 1.02 |
| P40 | llama 1B Q4_0 | 8192 | pp16384 | 2528.01 | 2584.50 | 1.02 |
| P40 | llama 1B Q4_0 | 16384 | pp16384 | 2224.16 | 2272.10 | 1.02 |
| P40 | llama 8B Q4_0 | 16 | pp16384 | 295.74 | 299.10 | 1.01 |
| P40 | llama 8B Q4_0 | 32 | pp16384 | 357.08 | 363.73 | 1.02 |
| P40 | llama 8B Q4_0 | 64 | pp16384 | 423.12 | 432.09 | 1.02 |
| P40 | llama 8B Q4_0 | 128 | pp16384 | 458.12 | 466.13 | 1.02 |
| P40 | llama 8B Q4_0 | 256 | pp16384 | 490.97 | 498.69 | 1.02 |
| P40 | llama 8B Q4_0 | 512 | pp16384 | 501.94 | 510.59 | 1.02 |
| P40 | llama 8B Q4_0 | 1024 | pp16384 | 513.06 | 524.31 | 1.02 |
| P40 | llama 8B Q4_0 | 2048 | pp16384 | 519.36 | 530.47 | 1.02 |
| P40 | llama 8B Q4_0 | 4096 | pp16384 | 517.18 | 527.73 | 1.02 |
| P40 | llama 8B Q4_0 | 8192 | pp16384 | 514.79 | 524.44 | 1.02 |
| P40 | llama 8B Q4_0 | 16384 | pp16384 | 502.04 | 511.36 | 1.02 |

The github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Sep 14, 2025.
@IMbackK commented Sep 15, 2025

I'll take a proper look at this, but will not be able to do so until the 17th.

@JohannesGaessler commented

Since you're already here, do you have an opinion on whether the HIP backend should be compiled with -ffast-math?

@IMbackK commented Sep 17, 2025

A quick grep suggests we use inf directly (see softmax), so blanket -ffast-math is out. We could use some of the individual fast-math flags, or apply -ffast-math on a per-function or per-translation-unit basis, but I'm not sure it's worth it. In the past, on other code, LLVM fast-math has made things slower on HIP for some reason, and before ROCm 6.1 there were bugs I encountered where fast-math just generated plain wrong code.
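
For anyone wondering why the infinities matter: the softmax masking pattern relies on adding -INFINITY to excluded positions so that expf() turns them into exact zeros, and -ffast-math implies -ffinite-math-only, under which the compiler may assume infinities never occur and break exactly this pattern. A minimal host-side sketch of the idea (not the actual ggml kernel code):

```cpp
#include <math.h>

// Masked softmax: positions with mask[i] == -INFINITY end up with
// weight exactly 0 because expf(-inf) == 0. Compiling this with
// -ffast-math (-ffinite-math-only) allows the compiler to assume the
// infinities away, which can silently change the result.
void softmax_masked(float * x, const float * mask, const int n) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; i++) {
        x[i]   += mask[i];             // mask[i] is 0.0f or -INFINITY
        max_val = fmaxf(max_val, x[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        x[i] = expf(x[i] - max_val);   // masked positions become 0.0f
        sum += x[i];
    }
    for (int i = 0; i < n; i++) {
        x[i] /= sum;
    }
}
```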

@IMbackK left a review comment

Changes look fine to me, and I can also confirm the performance delta on gfx1030. When making gfx908 use this code path, I can't reproduce the same magnitude of performance improvement as @JohannesGaessler does on gfx906, but I find no regression. I noticed that this PR reduced the amount of spilled VGPRs (although some instances still spill, e.g. _ZL15flash_attn_tileILi64ELi32ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiiiiiiiiiiiiiliiliiiiil), so it's possible that some of the extra improvement on gfx906 comes from reduced spills to scratch, whereas gfx908 can spill to AGPRs, which has a lower performance impact.

Side note:
some of the vector fattn kernels spill to high heaven:

```
Function Name: _ZL22flash_attn_vec_ext_f32ILi128ELi8EL9ggml_type8ELS0_8ELb1EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiiiiiiiiiiiiiliiliiiiil
     TotalSGPRs: 88
     VGPRs: 63
     AGPRs: 64
     ScratchSize [bytes/lane]: 1100
     Dynamic Stack: False
     Occupancy [waves/SIMD]: 4
     SGPRs Spill: 0
     VGPRs Spill: 298
     LDS Size [bytes/block]: 10240
```

That's 362 vector registers spilled in this kernel, as the AGPRs are also spills in this case.
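
For reference, per-kernel statistics like the block above can, as far as I know, be emitted by ROCm clang's resource-usage remarks (something along the lines of -Rpass-analysis=kernel-resource-usage; treat the exact flag as an assumption). The main generic knob for trading register pressure against occupancy, and thereby influencing spills, is __launch_bounds__; the snippet below only illustrates that mechanism and is not code from ggml.

```cpp
#include <hip/hip_runtime.h>

// Generic illustration of __launch_bounds__: promising the compiler a
// maximum block size lets it budget registers per thread. A tighter
// budget raises occupancy but can force values to be spilled, either
// to scratch memory or (on gfx908/CDNA) to AGPRs.
__global__ void __launch_bounds__(256) // max 256 threads per block (example value)
scale_rows(float * __restrict__ dst, const float * __restrict__ src,
           const float scale, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = scale * src[i];
    }
}
```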

@JohannesGaessler merged commit c959b67 into ggml-org:master on Sep 17, 2025 (46 of 48 checks passed).
@JohannesGaessler commented

One of my current efforts is to make the kernel parameters more configurable as a function of hardware. I intend to procure an RDNA4 GPU soon so that I can implement support for the AMD WMMA instructions in the mma FA kernel. In principle, if the mma kernel can be made to work, it should perform best, since it needs to hold fewer registers than the tile kernel and, unlike the WMMA kernel, does not have to go through shared memory. Can you give me a list of the AMD hardware that you have so that I can adjust my purchases for wider coverage?

@IMbackK commented Sep 17, 2025

Sure, I have gfx803 (Fiji / GCN3), gfx900 (Vega APU / GCN5), gfx906 (MI50 / GCN5.1), gfx908 (MI100 / CDNA1), and gfx1030 (RX 6800 XT / RDNA2).

I don't have any WMMA device at all, so any device with WMMA instructions would be very helpful. I know you don't intend to buy anything for actual use, but from a practical perspective the large-register-file RDNA3 GPUs (7900 XTX, 7900 XT, 7800 XT) tend to be better for AI inference than RDNA4, simply on account of being bigger devices with more CUs, VRAM, and bandwidth.

@broadbit-hu commented

@IMbackK I have 7800 XT, 7900 XT, and 7900 XTX cards; how can I help you?

@IMbackK commented Sep 17, 2025

@broadbit-hu Not at the moment. For regression testing it is useful to have people around who regularly run llama.cpp on a given arch, but we were talking about feature development. When doing feature development, the dev in question really needs to have a device with the relevant instructions on hand in one of their machines.

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Sep 19, 2025